/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file matmul_iterate_controller.h
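 * \brief Iteration controller that advances the current M/N block indices of a matmul and reports when traversal ends.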
 */
#ifndef IMPL_MATMUL_MODULES_MATMUL_ITERATOR_CONTROLLER_H
#define IMPL_MATMUL_MODULES_MATMUL_ITERATOR_CONTROLLER_H

#include "../matmul_module.h"
#include "../matmul_param.h"
#include "../../matmul_utils.h"
#include "../feature_trait/matmul_iter_ctrl_cfg.h"

namespace AscendC {
namespace Impl {
namespace Detail {

/**
 * \brief Steps the matmul iteration state (curM_, curN_, stepMIdx_, stepNIdx_, curStepM_, curStepN_)
 *        held in MATMUL_PARAM_VAR from one output block to the next.
 *
 * The traversal order is taken from ITER_CFG.iterOrder at compile time; when it is
 * IterateOrder::UNDEF the order is read from the tiling at run time instead.
 * Whenever an index wraps, the corresponding CopyCubeInA / CopyCubeInB module is Reset().
 *
 * \tparam IMPL      matmul implementation type (used by the MATMUL_* macros; not referenced directly here).
 * \tparam A_TYPE    A-matrix type descriptor (not referenced directly in this class).
 * \tparam B_TYPE    B-matrix type descriptor (not referenced directly in this class).
 * \tparam MM_CFG    matmul configuration, queried through IsBasicM() / IsBasicN().
 * \tparam ITER_CFG  iteration configuration; only its iterOrder member is read here.
 */
template <typename IMPL, class A_TYPE, class B_TYPE, const auto& MM_CFG, const auto& ITER_CFG>
class MatmulIterateController
{
    MATMUL_USE_MODULE(CopyCubeInA);
    MATMUL_USE_MODULE(CopyCubeInB);

public:
    /**
     * \brief Advance to the next block position.
     * \return true if a valid position is available, false once the traversal is exhausted.
     *
     * The first call after Reset() only initialises the indices and always returns true.
     */
    __aicore__ inline bool MoveNext()
    {
        if (unlikely(MATMUL_PARAM_VAR.isFirstIter_)) {
            return MoveOnFirstIterate();
        }

        // Everything below is selected at compile time; branches that are not taken are discarded.
        if constexpr (IsBasicM(MM_CFG) && IsBasicN(MM_CFG)) {
            // Both axes are "basic": there is no second position, so finish right after the first one.
            MATMUL_MODULE(CopyCubeInA)->Reset();
            MATMUL_MODULE(CopyCubeInB)->Reset();
            return false;
        } else if constexpr (ITER_CFG.iterOrder == IterateOrder::UNDEF) {
            // Order not fixed by the configuration: take it from the tiling at run time.
            auto& var = MATMUL_PARAM_VAR;
            if (likely(var.tiling_.GetIterateOrder() == static_cast<int>(IterateOrder::ORDER_M))) {
                return MoveOnIterateOrderM();
            } else {
                ASCENDC_ASSERT((var.tiling_.GetIterateOrder() == static_cast<int>(IterateOrder::ORDER_N)), {
                    KERNEL_LOG(KERNEL_ERROR, "iterateOrder is %d , which should be ORDER_N",
                    var.tiling_.GetIterateOrder());
                });
                return MoveOnIterateOrderN();
            }
        } else if constexpr (ITER_CFG.iterOrder == IterateOrder::ORDER_M) {
            return MoveOnIterateOrderM();
        } else {
            return MoveOnIterateOrderN();
        }
    }

    /**
     * \brief Re-arm the controller so that the next MoveNext() starts a fresh traversal.
     */
    __aicore__ inline void Reset()
    {
        MATMUL_PARAM_VAR.isFirstIter_ = true;
    }

private:
    /**
     * \brief Initialise all indices to the first position and clear the first-iteration flag.
     * \return always true.
     *
     * curStepM_ / curStepN_ are set to min(remaining iterations, tiling step) and are only
     * computed for an axis that is not "basic".
     */
    __aicore__ inline bool MoveOnFirstIterate()
    {
        auto& var = MATMUL_PARAM_VAR;
        var.isFirstIter_ = false;
        var.curM_ = 0;
        var.curN_ = 0;
        var.stepMIdx_ = 0;
        var.stepNIdx_ = 0;
        if constexpr (!IsBasicM(MM_CFG)) {
            // Clamp the first M step to the number of M iterations that remain.
            var.curStepM_ =
                (var.mIter_ - var.curM_) > var.tiling_.GetStepM() ?
                var.tiling_.GetStepM() : (var.mIter_ - var.curM_);
        }
        if constexpr (!IsBasicN(MM_CFG)) {
            // Clamp the first N step to the number of N iterations that remain.
            var.curStepN_ =
                (var.nIter_ - var.curN_) > var.tiling_.GetStepN() ?
                var.tiling_.GetStepN() : (var.nIter_ - var.curN_);
        }
        return true;
    }

    /**
     * \brief Advance one position in ORDER_M traversal (output along the M axis).
     * \return false when the traversal is exhausted, true otherwise.
     *
     * Non-basic-N case: curN_ walks the current N step [stepNIdx_, stepNIdx_ + curStepN_);
     * when that step is exhausted curN_ rewinds to its start and curM_ advances; when curM_
     * reaches mIter_ it rewinds to 0 and the N step window moves forward by curStepN_.
     */
    __aicore__ inline bool MoveOnIterateOrderM()
    {
        auto& var = MATMUL_PARAM_VAR;
        if constexpr (IsBasicN(MM_CFG)) {
            // Only M is iterated: every move resets CopyCubeInA, the final one also resets CopyCubeInB.
            MATMUL_MODULE(CopyCubeInA)->Reset();
            if (++var.curM_ >= var.mIter_) {
                MATMUL_MODULE(CopyCubeInB)->Reset();
                return false;
            }
        } else {
            if (++var.curN_ >= var.stepNIdx_ + var.curStepN_) {
                // Current N step finished for this M: rewind N and go to the next M.
                MATMUL_MODULE(CopyCubeInA)->Reset();
                var.curN_ = var.stepNIdx_;
                if (++var.curM_ >= var.mIter_) {
                    // All M positions done for this N step: move the N step window forward.
                    MATMUL_MODULE(CopyCubeInB)->Reset();
                    var.curM_ = 0;
                    var.stepNIdx_ += var.curStepN_;
                    if (var.stepNIdx_ >= var.nIter_) {
                        return false;
                    }
                    var.curN_ = var.stepNIdx_;
                    // Clamp the new N step to the number of N iterations that remain.
                    var.curStepN_ =
                        (var.nIter_ - var.curN_) > var.tiling_.GetStepN() ?
                        var.tiling_.GetStepN() : (var.nIter_ - var.curN_);
                }
            }
        }
        return true;
    }

    /**
     * \brief Advance one position in ORDER_N traversal; mirror image of MoveOnIterateOrderM().
     * \return false when the traversal is exhausted, true otherwise.
     *
     * Non-basic-M case: curM_ walks the current M step [stepMIdx_, stepMIdx_ + curStepM_);
     * when that step is exhausted curM_ rewinds to its start and curN_ advances; when curN_
     * reaches nIter_ it rewinds to 0 and the M step window moves forward by curStepM_.
     */
    __aicore__ inline bool MoveOnIterateOrderN()
    {
        auto& var = MATMUL_PARAM_VAR;
        if constexpr (IsBasicM(MM_CFG)) {
            // Only N is iterated: every move resets CopyCubeInB, the final one also resets CopyCubeInA.
            MATMUL_MODULE(CopyCubeInB)->Reset();
            if (++var.curN_ >= var.nIter_) {
                MATMUL_MODULE(CopyCubeInA)->Reset();
                return false;
            }
        } else {
            if (++var.curM_ >= var.stepMIdx_ + var.curStepM_) {
                // Current M step finished for this N: rewind M and go to the next N.
                MATMUL_MODULE(CopyCubeInB)->Reset();
                var.curM_ = var.stepMIdx_;
                if (++var.curN_ >= var.nIter_) {
                    // All N positions done for this M step: move the M step window forward.
                    MATMUL_MODULE(CopyCubeInA)->Reset();
                    var.curN_ = 0;
                    var.stepMIdx_ += var.curStepM_;
                    if (var.stepMIdx_ >= var.mIter_) {
                        return false;
                    }
                    var.curM_ = var.stepMIdx_;
                    // Clamp the new M step to the number of M iterations that remain.
                    var.curStepM_ =
                        (var.mIter_ - var.curM_) > var.tiling_.GetStepM() ?
                        var.tiling_.GetStepM() : (var.mIter_ - var.curM_);
                }
            }
        }
        return true;
    }
};

}  // namespace Detail
}  // namespace Impl
}  // namespace AscendC
#endif
